from Criterion import SoftmaxLoss
from main import *
import os,sys
import argparse
"""
实验说明: 复现L-softmax.
https://github.com/wy1iu/LargeMargin_Softmax_Loss/blob/master/myexamples/cifar10/cifar_solver.prototxt
"""
from models import cnn

# Stem of this script's filename (e.g. "exp1" for exp1.py). Uses splitext
# instead of the brittle `[:-3]` slice so any extension length is handled.
fname = os.path.splitext(os.path.basename(__file__))[0]


def parse_arg():
    """Parse command-line arguments for the series-experiment runner.

    Returns:
        tuple[argparse.Namespace, list[str]]: the parsed known arguments and
        the list of unrecognized argv tokens (``parse_known_args`` result).
    """
    parser = argparse.ArgumentParser(description='Series tool for ML-factors search.')
    parser.add_argument('-e', '--epochs', default=100, type=int)
    parser.add_argument('--lr', default=0.1, type=float)
    # `--out` is dereferenced unconditionally by the caller (output dirs,
    # summary/log paths), so require it here for a clear argparse error
    # instead of a later `TypeError: ... NoneType` on `out + '/archs'`.
    parser.add_argument('-o', '--out', type=str, required=True)
    parser.add_argument('--data', default='./data')
    parser.add_argument('--ckp', default='')
    parser.add_argument('--device', default='cuda')

    parser.add_argument('-s', '--series', nargs='*', help='对照变量', required=True)
    return parser.parse_known_args()

if __name__ == '__main__':

    arg, _unknow = parse_arg()

    # Experiment state persisted to <out>/ckp.pth so an interrupted sweep
    # can be resumed with --ckp from the first unfinished trial.
    global_state = {'lr': arg.lr,
                    'process': 0
                    }

    # Resume: restore learning rate and next-trial index from the checkpoint.
    if arg.ckp != '':
        states = torch.load(arg.ckp)
        global_state['lr'] = states['lr']
        global_state['process'] = states['process']

    device = torch.device(arg.device)
    train_loader, test_loader = gen_data(arg.data)

    out = arg.out
    # exist_ok=True replaces the old bare `try/except: pass`, which silently
    # swallowed real failures (e.g. permission errors), not just existing dirs.
    os.makedirs(out + '/archs', exist_ok=True)

    print(arg.series)
    # SECURITY NOTE: eval() executes arbitrary code from the command line.
    # Tolerable only because this is a local research script run by its owner.
    series = [eval(t) for t in arg.series]

    summary = open(f'{out}/summary.txt', 'a+')
    # Only write the header on a fresh run, not when resuming mid-sweep.
    if global_state['process'] == 0:
        lines = [
            '-------------arguments--------',
            ' '.join(sys.argv),
            '-------------result--------']
        summary.writelines([line + '\n' for line in lines])
        summary.flush()
    log = open(f'{out}/log.txt', 'a+')

    setup_seed(123)
    for i in range(global_state['process'], len(series)):
        variable = series[i]
        best_acc = 0
        lr = arg.lr
        # Fresh model/criterion/optimizer per trial so runs are independent.
        model = cnn(256).to(device)
        loss_func = SoftmaxLoss(256, 10).to(device)
        optimizer = optim.SGD(model.parameters(), lr, momentum=0.9)
        optimizer.add_param_group({'lr': lr, 'params': loss_func.parameters()})

        # train
        log.write(f'/*-------i:{i},{variable}----------*/')
        epochs = variable
        for epoch in range(epochs):
            # Step-decay schedule: divide lr by 10 at 50%/70%/85% of training.
            if epoch in [int(epochs * 0.5), int(epochs * 0.7), int(epochs * 0.85)]:
                lr /= 10
                for p in optimizer.param_groups:
                    p['lr'] = lr

            train_loss = train(train_loader, model, optimizer, loss_func, epoch, device)
            test_loss, acc = test(test_loader, model, loss_func, epoch, device)

            log.writelines([
                '{:03d} {:6.4f}  {:6.4f} {:6.4f}\n'.format(epoch, train_loss, test_loss, acc),
            ])
            log.flush()

            # Keep the weights with the best test accuracy seen so far.
            if acc > best_acc:
                save_arch(model, loss_func, optimizer, f'{out}/archs/best_{i}.pth')
                best_acc = acc

        # Trial i is finished: record i + 1 so a resumed run skips it.
        # (The old code saved i, which made a resume repeat the completed trial.)
        global_state['process'] = i + 1
        torch.save(global_state, out + '/ckp.pth')

        summary.write('{} {:6.4f}\n'
                      .format(variable, best_acc,))

    # NOTE: the original ended with `for ... else: raise NotImplementedError()`.
    # A for-else runs whenever the loop finishes without `break` (there is no
    # break here), so the script always crashed after completing every trial;
    # that guard is removed and the output files are closed instead.
    summary.close()
    log.close()
